Support boolean attention mask in Attention(23) CUDA - MHA case #27428

Merged

justinchuby merged 7 commits into main from titaiwang/suuport_book_attn_mask_mha on Feb 27, 2026

Conversation

titaiwangms (Contributor) commented on Feb 24, 2026

Replaces and relands #27129.

Comparison between this PR's approach (pre-conversion) and handling the bool mask inline in softmax.

Tradeoffs

| Category | Pre-conversion (current) | Inline in softmax |
| --- | --- | --- |
| Memory | Extra buffer (`num_elements * sizeof(T)`) | None — reads 1-byte bool directly |
| Kernel launches | +1 simple elementwise kernel | Zero extra |
| Code complexity | 3 files, ~40 lines added | 6+ kernel templates, macros, dispatch logic, data structs |
| Risk | Low — softmax path untouched | High — modifying battle-tested softmax kernels used by MHA + GQA contrib ops |
| Perf impact | Negligible — mask is small vs. QKV; conversion is memory-bound and fast | Slightly better theoretical bandwidth |
| Maintainability | Clean separation of concerns | Adds a template dimension across all softmax variants |
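
For scale (an illustrative shape, not one from the PR): a `[2, 8, 512, 512]` fp16 bias buffer costs 2 × 8 × 512 × 512 × 2 bytes = 8 MiB, consistent with the table's claim that the mask is small relative to QKV.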

This pull request enhances the ONNX Runtime CUDA Attention operator to support boolean attention masks (bool masks) in the Multi-Head Attention (MHA) path, converting them to additive attention bias on the GPU. It also improves test coverage to ensure correctness and parity with the CPU implementation. The main changes include implementing a CUDA kernel for mask conversion, updating the operator logic to handle bool masks, clarifying broadcasting rules, and adding comprehensive unit tests.

CUDA Attention Operator Improvements:

  • Implemented a CUDA kernel (LaunchConvertBoolMaskToAttentionBias) that converts boolean attention masks to additive bias (True → 0.0, False → mask_filter_value) for the MHA path, ensuring efficient GPU execution.
  • Updated attention.cc to use this kernel, correctly handle bool masks in the MHA path, and clarify the broadcasting logic and mask shape interpretation for both GQA and MHA. A sketch of this conversion path follows the list.
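
For orientation, here is a minimal sketch of what such a pre-conversion path can look like; the parameter names, block size, and launcher shape are illustrative assumptions, not the PR's exact code:

```cuda
// Sketch only: a bool-mask -> additive-bias conversion kernel and launcher.
// Names and launch configuration are assumptions, not the PR's exact code.
#include <cstdint>
#include <cuda_runtime.h>

template <typename T>
__global__ void BoolMaskToBiasKernel(const bool* mask, T* bias,
                                     int64_t num_elements, float filter_value) {
  // One thread per element: True -> 0 (attend), False -> filter_value (mask out).
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    bias[idx] = mask[idx] ? T(0.0f) : T(filter_value);
  }
}

template <typename T>
void LaunchBoolMaskToBias(cudaStream_t stream, const bool* mask, T* bias,
                          int64_t num_elements, float filter_value) {
  constexpr int kBlockSize = 256;
  int64_t blocks = (num_elements + kBlockSize - 1) / kBlockSize;
  // Note: the narrowing cast below is the grid-size concern raised in review.
  BoolMaskToBiasKernel<T><<<static_cast<unsigned int>(blocks), kBlockSize, 0,
                            stream>>>(mask, bias, num_elements, filter_value);
}
```

The converted bias then flows into the existing additive-bias path, which is why the battle-tested softmax kernels stay untouched.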

Testing and Documentation Enhancements:

  • Added new test cases and a dedicated test class to validate the correctness of boolean mask handling in the MHA path, ensuring parity with the CPU implementation for 2D, 3D, and 4D mask shapes.
  • Improved comments and documentation in both code and tests to clarify ONNX broadcasting rules and mask shape expectations for different attention paths.

Test Coverage and Reliability:

  • Enabled CUDA-based tests for boolean mask scenarios previously only tested on CPU, and adjusted test logic to ensure correct handling of edge cases (e.g., all-false masks).

These changes make the CUDA Attention operator more robust and feature-complete, aligning its behavior with the CPU implementation and ONNX specifications.

titaiwangms added the `ep:CUDA` label (issues related to the CUDA execution provider) on Feb 24, 2026
Copilot AI left a comment

Pull request overview

Adds CUDA support for boolean attn_mask in the ONNX Attention (opset 23) CUDA implementation (MHA path) by converting boolean masks into additive attention bias on-GPU, and enables corresponding CUDA test coverage.

Changes:

  • Added CUDA kernel + launcher to convert bool attention masks into additive bias (true -> 0, false -> mask_filter_value).
  • Updated CUDA Attention<T>::ComputeInternal to accept boolean masks and run the conversion into a scratch buffer.
  • Enabled CUDA execution for existing boolean-mask attention tests.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| onnxruntime/test/providers/cpu/llm/attention_op_test.cc | Enables CUDA runs for bool-mask test cases (including an all-false degenerate mask). |
| onnxruntime/core/providers/cuda/llm/attention_mask_impl.h | Declares the new bool-mask-to-bias conversion launcher. |
| onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu | Implements the CUDA kernel and explicit instantiations for float/half/bfloat16. |
| onnxruntime/core/providers/cuda/llm/attention.cc | Uses the new conversion path for boolean masks and aligns mask_filter_value with the CPU helper. |


titaiwangms requested a review from xadupre on February 24, 2026 18:25
justinchuby self-requested a review on February 26, 2026 22:47
justinchuby (Contributor) commented on Feb 26, 2026

🔍 Independent Review by a Team of Copilots

Review team: 2 Architects (claude-opus-4.6), 1 Code Reviewer (gemini-3-pro-preview), 1 Critical Reviewer (claude-sonnet-4.6) — coordinated by a Copilot Project Lead

Verdict: ✅ Conditional Approve

Excellent PR — clean design, great documentation, well-scoped change. One blocking issue to fix, two non-blocking suggestions.


🔴 BLOCKING: Grid Size Overflow (H1)

File: attention_mask_impl.cu, function LaunchConvertBoolMaskToAttentionBias

blocks is int64_t but cast to unsigned int for the CUDA <<<grid, block>>> launch. CUDA gridDim.x max is 2³¹−1. For a 4D mask on a long-context model (e.g., shape [1, 32, 128K, 128K]), blocks can exceed this limit, causing silent data corruption — only a subset of mask elements would be converted.

Suggested fix — grid-stride loop (common CUDA best practice):

```cuda
template <typename T>
__global__ void ConvertBoolMaskToAttentionBiasKernel(const bool* attn_mask,
                                                     T* attention_bias,
                                                     int64_t num_elements,
                                                     float mask_filter_value) {
  // Grid-stride loop: stays correct for any grid size, so every element is
  // converted even when num_elements exceeds gridDim.x * blockDim.x.
  for (int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < num_elements;
       idx += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    attention_bias[idx] = attn_mask[idx] ? T(0.0f) : T(mask_filter_value);
  }
}
```

And cap the grid: `min(blocks, 65535u)` or similar.
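
A matching capped launch could look like the following (a sketch; the block size and cap value are assumptions, and the kernel is the one above):

```cuda
// Sketch: cap the grid; the grid-stride loop above covers the remainder.
#include <algorithm>  // std::min

constexpr int kBlockSize = 256;
int64_t blocks = (num_elements + kBlockSize - 1) / kBlockSize;
unsigned int grid = static_cast<unsigned int>(std::min<int64_t>(blocks, 65535));
ConvertBoolMaskToAttentionBiasKernel<T><<<grid, kBlockSize, 0, stream>>>(
    attn_mask, attention_bias, num_elements, mask_filter_value);
```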


🟡 Non-blocking suggestions

M1 — Unused include (attention.cc:6):
#include "core/providers/cpu/llm/attention.h" — no symbols from this header are used in the changed code. Creates an unnecessary CPU→CUDA provider dependency. Consider removing it, or using mask_filter_value<T>() from it instead of reimplementing with std::numeric_limits<T>::lowest() (the helper may have type-specific specializations).

M2 — Test coverage gaps (test_attention.py):
Python tests only cover fp16 with batch_size=2 and a single config. Consider adding:

  • fp32 test method (kernel is instantiated for float/half/bfloat16 but only half is tested)
  • batch_size=1 config (exercises broadcasting edge case)
  • All-true mask edge case

✅ Confirmed Safe

We investigated several potential concerns and confirmed they are not issues:

| Concern | Status | Evidence |
| --- | --- | --- |
| mask_filter_value change from -10000 to lowest() causing regression | Safe | Traced the full call chain: the value is only consumed by ComputeSoftmaxWithRawMask, which requires data.mask_index != nullptr. The new Attention op sets mask_index = nullptr, so the value is never consumed by softmax in non-bool paths. |
| Numerical stability / NaN from overflow | Fixed | Uses std::numeric_limits<T>::lowest() (-65504 for fp16: finite, not -inf). Copilot review fix correctly applied. |
| Buffer lifetime / use-after-free | Safe | converted_mask_buffer lives at function scope; all GPU work is on the same CUDA stream with stream-ordered allocation. |
| CPU/CUDA parity | Achieved | The CPU path already uses std::numeric_limits<T>::lowest() via mask_filter_value<T>(). |

👏 Highlights

  • Broadcasting documentation (lines 516–527) is outstanding — the right-aligned notation with examples is best-in-class
  • Pre-conversion architecture is the right call — low risk, clean separation, negligible perf overhead vs. modifying battle-tested softmax kernels
  • GQA vs MHA asymmetry is well-documented and architecturally justified
  • Kernel quality is excellent — clean, focused, properly templated

Review performed by a team of Copilots — 4 specialist agents with different models and perspectives, coordinated by a Copilot Project Lead. Findings represent cross-validated consensus.

titaiwangms (Contributor, Author) commented

A 4D mask [1, 32, 128K, 128K] = 32 × 128K × 128K ≈ 550 billion elements. At 1 byte per bool, that's 550 GB — it wouldn't even fit in GPU memory. The max gridDim.x is 2^31 - 1 ≈ 2.1 billion, and with 256 threads per block that covers ~550 billion elements, which is coincidentally close to the limit.

But practically, even a 4D mask [8, 32, 8192, 8192] (very generous) = ~17 billion elements = ~17 GB of bool, which is already well beyond what any real model would use for an attention mask. With 256 threads/block, blocks = 17B / 256 ≈ 66M, well under the 2.1B grid limit.

That said, the fix is trivially correct and costs nothing at runtime, so it's a reasonable defensive improvement. It's up to you — both "accept and fix" (takes 2 lines) and "push back as impractical" are valid.

justinchuby (Contributor) replied

Should have realized that 😅 AI team failed.

justinchuby merged commit bf71213 into main on Feb 27, 2026 (90 of 94 checks passed).
justinchuby deleted the titaiwang/suuport_book_attn_mask_mha branch on February 27, 2026.